{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
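    "多组单智能体: 每个智能体各自拥有独立的A2C (actor + critic).\n",
    "\n",
    "critic共享state (输入为所有智能体拼接后的state), 但是actor不共享state (只输入自己的观测).\n",
    "\n",
    "Multiple independent single-agent A2C learners: each agent's critic sees the concatenated states of all agents, while each actor sees only its own observation."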
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    \"\"\"Wrap PettingZoo's MPE simple_tag (AEC API) as a two-agent env.\n",
    "\n",
    "    reset() returns one state per agent; step() takes one discrete action\n",
    "    per agent and returns (states, per-agent summed rewards, over).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        from pettingzoo.mpe import simple_tag_v3\n",
    "        # 1 adversary, 1 good agent and 1 obstacle. max_cycles is huge so the\n",
    "        # episode length is controlled by the 100-step limit in _step instead.\n",
    "        env = simple_tag_v3.env(num_good=1,\n",
    "                                num_adversaries=1,\n",
    "                                num_obstacles=1,\n",
    "                                max_cycles=1e8,\n",
    "                                render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        # Number of (frame-skipped) steps taken in the current episode.\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Reset the episode and return the initial per-agent states.\"\"\"\n",
    "        self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return self.state()\n",
    "\n",
    "    def state(self):\n",
    "        \"\"\"Return every agent's current observation as a list of floats.\n",
    "\n",
    "        The last agent's observation is padded with two zeros, presumably so\n",
    "        that all agents' observations have the same length (the actors take\n",
    "        10 inputs) -- verify against the env's observation spaces.\n",
    "        \"\"\"\n",
    "        state = []\n",
    "        for i in self.env.agents:\n",
    "            # Read from the wrapped env, not the global `env`, so the class\n",
    "            # works no matter what name the instance is bound to.\n",
    "            state.append(self.env.observe(i).tolist())\n",
    "        state[-1].extend([0.0, 0.0])\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Apply one action per agent with a 5-sub-step frame skip.\n",
    "\n",
    "        The given actions are used on the first sub-step only; the other\n",
    "        four pass -1 (mapped to discrete action 0 by _step). Rewards are\n",
    "        summed over the sub-steps, and the whole call counts as a single\n",
    "        step toward the episode limit.\n",
    "        \"\"\"\n",
    "        reward_sum = [0] * len(action)\n",
    "        for i in range(5):\n",
    "            if i != 0:\n",
    "                action = [-1] * len(action)\n",
    "            next_state, reward, over = self._step(action)\n",
    "            for j in range(len(reward_sum)):\n",
    "                reward_sum[j] += reward[j]\n",
    "            # _step counted this sub-step; undo it so only this call counts.\n",
    "            self.step_n -= 1\n",
    "\n",
    "        self.step_n += 1\n",
    "\n",
    "        return next_state, reward_sum, over\n",
    "\n",
    "    def _step(self, action):\n",
    "        \"\"\"Let every agent act once; return (states, rewards, over).\"\"\"\n",
    "        # Shift by one: actor actions 0..3 become 1..4 and the frame-skip -1\n",
    "        # becomes 0 (presumably MPE's no-op action -- verify).\n",
    "        for i, _ in enumerate(self.env.agent_iter(len(action))):\n",
    "            self.env.step(action[i] + 1)\n",
    "\n",
    "        reward = [self.env.rewards[i] for i in self.env.agents]\n",
    "\n",
    "        _, _, termination, truncation, _ = self.env.last()\n",
    "        over = termination or truncation\n",
    "\n",
    "        # Limit the episode to 100 steps.\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 100:\n",
    "            over = True\n",
    "\n",
    "        return self.state(), reward, over\n",
    "\n",
    "    def show(self):\n",
    "        \"\"\"Render the current frame with matplotlib.\"\"\"\n",
    "        from matplotlib import pyplot as plt\n",
    "        plt.figure(figsize=(3, 3))\n",
    "        plt.imshow(self.env.render())\n",
    "        plt.show()\n",
    "\n",
    "# Create the environment, reset it, and draw the initial frame.\n",
    "env = MyWrapper()\n",
    "env.reset()\n",
    "\n",
    "env.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class A2C:\n",
    "    \"\"\"Advantage actor-critic learner for a single agent.\n",
    "\n",
    "    Args:\n",
    "        model_actor: network mapping a batch of states to action\n",
    "            probabilities (one row per state).\n",
    "        model_critic: network mapping a batch of states to values.\n",
    "        model_critic_delay: target copy of the critic. It is overwritten\n",
    "            with model_critic's weights here, frozen, and then moved toward\n",
    "            the critic by soft_update after every critic step.\n",
    "        optimizer_actor: optimizer over model_actor's parameters.\n",
    "        optimizer_critic: optimizer over model_critic's parameters.\n",
    "        gamma: discount factor of the TD target (default 0.99).\n",
    "        tau: soft-update rate of the target critic (default 0.01).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model_actor, model_critic, model_critic_delay,\n",
    "                 optimizer_actor, optimizer_critic, gamma=0.99, tau=0.01):\n",
    "        self.model_actor = model_actor\n",
    "        self.model_critic = model_critic\n",
    "        self.model_critic_delay = model_critic_delay\n",
    "        self.optimizer_actor = optimizer_actor\n",
    "        self.optimizer_critic = optimizer_critic\n",
    "        self.gamma = gamma\n",
    "        self.tau = tau\n",
    "\n",
    "        # Start the target critic as an exact copy and keep it out of autograd.\n",
    "        self.model_critic_delay.load_state_dict(self.model_critic.state_dict())\n",
    "        self.requires_grad(self.model_critic_delay, False)\n",
    "\n",
    "    def soft_update(self, _from, _to):\n",
    "        \"\"\"Move _to toward _from in place: to = (1 - tau) * to + tau * from.\"\"\"\n",
    "        with torch.no_grad():\n",
    "            for src, dst in zip(_from.parameters(), _to.parameters()):\n",
    "                dst.copy_(dst * (1 - self.tau) + src * self.tau)\n",
    "\n",
    "    def requires_grad(self, model, value):\n",
    "        \"\"\"Set requires_grad to `value` on every parameter of `model`.\"\"\"\n",
    "        for param in model.parameters():\n",
    "            param.requires_grad_(value)\n",
    "\n",
    "    def train_critic(self, state, reward, next_state, over):\n",
    "        \"\"\"Run one TD(0) update of the critic and return the advantage.\n",
    "\n",
    "        Args:\n",
    "            state: batch of critic inputs at time t.\n",
    "            reward: rewards at time t, shaped like the critic output.\n",
    "            next_state: batch of critic inputs at time t + 1.\n",
    "            over: 1 where the episode ended at t, else 0; those steps do\n",
    "                not bootstrap from next_state.\n",
    "\n",
    "        Returns:\n",
    "            (target - value).detach(): the TD error, used by train_actor as\n",
    "            the advantage (subtracting the value acts as a baseline).\n",
    "        \"\"\"\n",
    "        self.requires_grad(self.model_critic, True)\n",
    "        self.requires_grad(self.model_actor, False)\n",
    "\n",
    "        value = self.model_critic(state)\n",
    "\n",
    "        # Bootstrapped target from the frozen target critic.\n",
    "        with torch.no_grad():\n",
    "            target = self.model_critic_delay(next_state)\n",
    "        target = target * self.gamma * (1 - over) + reward\n",
    "\n",
    "        # Temporal-difference loss.\n",
    "        loss = torch.nn.functional.mse_loss(value, target)\n",
    "\n",
    "        loss.backward()\n",
    "        self.optimizer_critic.step()\n",
    "        self.optimizer_critic.zero_grad()\n",
    "        self.soft_update(self.model_critic, self.model_critic_delay)\n",
    "\n",
    "        return (target - value).detach()\n",
    "\n",
    "    def train_actor(self, state, action, value):\n",
    "        \"\"\"Run one policy-gradient update of the actor.\n",
    "\n",
    "        Args:\n",
    "            state: batch of actor inputs.\n",
    "            action: LongTensor of taken action indices, one column per row\n",
    "                (used as the index of gather along dim 1).\n",
    "            value: advantage of each taken action, e.g. from train_critic.\n",
    "\n",
    "        Returns:\n",
    "            The actor loss as a Python float.\n",
    "        \"\"\"\n",
    "        self.requires_grad(self.model_critic, False)\n",
    "        self.requires_grad(self.model_actor, True)\n",
    "\n",
    "        # Re-evaluate the probability of each action that was actually taken.\n",
    "        prob = self.model_actor(state)\n",
    "        prob = prob.gather(dim=1, index=action)\n",
    "\n",
    "        # Policy gradient: maximize log pi(a|s) * advantage, where the\n",
    "        # advantage is estimated by the critic. 1e-8 avoids log(0).\n",
    "        prob = (prob + 1e-8).log() * value\n",
    "        loss = -prob.mean()\n",
    "\n",
    "        loss.backward()\n",
    "        self.optimizer_actor.step()\n",
    "        self.optimizer_actor.zero_grad()\n",
    "\n",
    "        return loss.item()\n",
    "\n",
    "\n",
    "def _mlp(in_features, out_features, softmax=False):\n",
    "    \"\"\"Linear-ReLU-Linear-ReLU-Linear network with 64 hidden units,\n",
    "    optionally followed by a Softmax over dim 1.\"\"\"\n",
    "    layers = [\n",
    "        torch.nn.Linear(in_features, 64),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.Linear(64, 64),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.Linear(64, out_features),\n",
    "    ]\n",
    "    if softmax:\n",
    "        layers.append(torch.nn.Softmax(dim=1))\n",
    "    return torch.nn.Sequential(*layers)\n",
    "\n",
    "\n",
    "# Actor of each agent: its own observation (10 values) -> probabilities of 4 actions.\n",
    "model_actor = [_mlp(10, 4, softmax=True) for _ in range(2)]\n",
    "# Critic of each agent (and its delayed copy): both agents' states\n",
    "# concatenated (2 x 10 = 20 values) -> one scalar value.\n",
    "model_critic = [_mlp(20, 1) for _ in range(2)]\n",
    "model_critic_delay = [_mlp(20, 1) for _ in range(2)]\n",
    "\n",
    "optimizer_actor = [\n",
    "    torch.optim.Adam(net.parameters(), lr=1e-3) for net in model_actor\n",
    "]\n",
    "optimizer_critic = [\n",
    "    torch.optim.Adam(net.parameters(), lr=5e-3) for net in model_critic\n",
    "]\n",
    "\n",
    "a2c = [\n",
    "    A2C(actor, critic, critic_delay, opt_actor, opt_critic)\n",
    "    for actor, critic, critic_delay, opt_actor, opt_critic in zip(\n",
    "        model_actor, model_critic, model_critic_delay, optimizer_actor,\n",
    "        optimizer_critic)\n",
    "]\n",
    "\n",
    "# Everything is reachable through a2c; drop the module-level handles.\n",
    "del _mlp\n",
    "model_actor = None\n",
    "model_critic = None\n",
    "model_critic_delay = None\n",
    "optimizer_actor = None\n",
    "optimizer_critic = None\n",
    "\n",
    "a2c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from IPython import display\n",
    "import random\n",
    "\n",
    "\n",
    "# Play one episode and record its transitions.\n",
    "def play(show=False):\n",
    "    \"\"\"Play one episode with the current actors and record the data.\n",
    "\n",
    "    Args:\n",
    "        show: if True, redraw the environment after every step.\n",
    "\n",
    "    Returns:\n",
    "        (state, action, reward, next_state, over, reward_sum) where\n",
    "        state / next_state: FloatTensor (T, 2, obs_dim) per-agent states;\n",
    "        action: LongTensor (T, 2, 1) sampled action indices;\n",
    "        reward: FloatTensor (T, 2, 1) per-agent rewards;\n",
    "        over: LongTensor (T, 1) episode-end flags;\n",
    "        reward_sum: list with each agent's total reward for the episode.\n",
    "    \"\"\"\n",
    "    state = []\n",
    "    action = []\n",
    "    reward = []\n",
    "    next_state = []\n",
    "    over = []\n",
    "\n",
    "    s = env.reset()\n",
    "    o = False\n",
    "    while not o:\n",
    "        a = []\n",
    "        # Sampling only: skip building an autograd graph on every forward pass.\n",
    "        with torch.no_grad():\n",
    "            for i in range(2):\n",
    "                # Each agent samples from its own actor given its own state.\n",
    "                prob = a2c[i].model_actor(torch.FloatTensor(s[i]).reshape(\n",
    "                    1, -1))[0].tolist()\n",
    "                a.append(random.choices(range(4), weights=prob, k=1)[0])\n",
    "\n",
    "        # Execute the joint action.\n",
    "        ns, r, o = env.step(a)\n",
    "\n",
    "        state.append(s)\n",
    "        action.append(a)\n",
    "        reward.append(r)\n",
    "        next_state.append(ns)\n",
    "        over.append(o)\n",
    "\n",
    "        s = ns\n",
    "\n",
    "        if show:\n",
    "            display.clear_output(wait=True)\n",
    "            env.show()\n",
    "\n",
    "    state = torch.FloatTensor(state)\n",
    "    action = torch.LongTensor(action).unsqueeze(-1)\n",
    "    reward = torch.FloatTensor(reward).unsqueeze(-1)\n",
    "    next_state = torch.FloatTensor(next_state)\n",
    "    over = torch.LongTensor(over).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over, reward.sum(\n",
    "        dim=0).flatten().tolist()\n",
    "\n",
    "\n",
    "# Smoke test: play one episode and show each agent's total reward.\n",
    "state, action, reward, next_state, over, reward_sum = play()\n",
    "\n",
    "reward_sum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def train(epochs=500, log_every=100, test_games=20):\n",
    "    \"\"\"Train every agent's A2C learner on `epochs` freshly played episodes.\n",
    "\n",
    "    Args:\n",
    "        epochs: number of training episodes.\n",
    "        log_every: every `log_every` epochs (and at the last epoch),\n",
    "            evaluate over `test_games` episodes and print the epoch, both\n",
    "            actor losses and the mean per-agent episode reward. Must be\n",
    "            positive and smaller than `epochs` to report any progress.\n",
    "        test_games: number of episodes averaged per evaluation.\n",
    "    \"\"\"\n",
    "    for epoch in range(epochs):\n",
    "        state, action, reward, next_state, over, _ = play()\n",
    "\n",
    "        # Critic input: all agents' states concatenated per time step.\n",
    "        state_c = state.flatten(start_dim=1)\n",
    "        next_state_c = next_state.flatten(start_dim=1)\n",
    "\n",
    "        # Per agent: the critic sees the shared state, the actor only its own.\n",
    "        losses = []\n",
    "        for i in range(2):\n",
    "            value = a2c[i].train_critic(state_c, reward[:, i], next_state_c,\n",
    "                                        over)\n",
    "            losses.append(a2c[i].train_actor(state[:, i], action[:, i], value))\n",
    "\n",
    "        if epoch % log_every == 0 or epoch == epochs - 1:\n",
    "            test_result = [play()[-1] for _ in range(test_games)]\n",
    "            test_result = torch.FloatTensor(test_result).mean(dim=0).tolist()\n",
    "            print(epoch, losses, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEYCAYAAABlUvL1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlgElEQVR4nO3df3RTdZ438HfS/KA/SGILTehji50Bt1QK1hbaiDOjUqnYnRGpPg5TpcuystSAQmfQ7TmKO7haDuyMMzoIjmdH2BVBOQ4qLIilaBmWUKDIWlqsgDipQFqk06Tt0LRNPs8fPr1j5FfTX7ex79c59xx6v9+b+7nX3Lc395ubqxERARHRINOqXQARDU8MHyJSBcOHiFTB8CEiVTB8iEgVDB8iUgXDh4hUwfAhIlUwfIhIFQwfIlKFauGzZs0a3HDDDRgxYgSysrJw8OBBtUohIhWoEj5vvvkmiouL8cwzz+DIkSOYPHkycnNz0djYqEY5RKQCjRo3lmZlZWHKlCn43e9+BwAIBAJITEzE4sWL8S//8i+DXQ4RqUA32Cvs6OhAVVUVSkpKlHlarRY5OTlwOp2XXcbn88Hn8yl/BwIBNDU1IS4uDhqNZsBrJqKeERG0tLQgISEBWu3VP1gNevh89dVX8Pv9sFqtQfOtVis+/fTTyy5TWlqKX/7yl4NRHhH1g/r6elx//fVX7TPo4dMbJSUlKC4uVv72eDxISkpCfX09TCaTipUR0Td5vV4kJiZi5MiR1+w76OEzatQoREREoKGhIWh+Q0MDbDbbZZcxGo0wGo2XzDeZTAwfoiGoJ5dDBn20y2AwICMjA+Xl5cq8QCCA8vJy2O32wS6HiFSiyseu4uJiFBYWIjMzE1OnTsVvfvMbtLW1Yd68eWqUQ0QqUCV8HnzwQZw/fx7Lly+H2+3GzTffjPfff/+Si9BE9N2lyvd8+srr9cJsNsPj8fCaD9EQEsqxyXu7iEgVDB8iUgXDh4hUwfAhIlUwfIhIFQwfIlJFWNzbRUSDKxAIoLW1FV+6XIAINFotxt5wAyKjovrtlyQYPkSk8Pv9OHbkCA5/8AG+qqqCoanp6waNBp2jR8M2dSqyZs7EjampfQ4hfsmQiAAAba2t+GjrVpz+r/9CnAgiNJqggBERdIngK50OaUVFuHXGDBhHjAh6DX7JkIhC0tbWhreefx7n1q9HPACdVnvJmY1Go4Feq4XN78dnv/0ttr70Ejo6Onq9ToYP0TDn9/tRsXUr/JWVMOt01/w4pdFoEKvT4S+7dqGyvBy9/fDE8CEa5o59/DFO/ed/whQREdJycVotDr/0Ev78+ee9Wi/Dh2gYCwQCqPrgA8SJhHwBWaPRINbnw+GPPurV2Q/Dh2gYa2ttxfmqKuh6OXIVGRGBP+/bh87OzpCXZfgQDWNnvvwShgsXej1srtFo0FVfj6buIfkQMHyIhrH++p4NP3YRUUi0Gg2kr99Y1mqv+Yyuyy7Wt7USUThLGjsWXVZrr4fLRQSG730PcXFxIS/L8CEaxkZERmLM1KnoDAR6tXxbVxfG/fCHiAhxmB5g+BANaxqNBtkzZ+IrgyHksx8RQbPFgszbb+/VBWuGD9EwNy4lBTc7HGjy+0NarkEEdzzxBGxjxvRqvQwfomFOo9EgOycHsXl5OO/3X/MMSETgFkHyz36GtClTej1MH3L47N27Fz/+8Y+RkJAAjUaDd95555LCli9fjjFjxiAyMhI5OTk4ceJEUJ+mpiYUFBTAZDLBYrFg/vz5aG1t7dUGEFHfGY1GzFq0CJOfeAJn9Hq0dXZeEkIigpbOTpyJicEPnn0WM+fO7dW1nm4hh09bWxsmT56MNWvWXLZ91apVePHFF7Fu3TpUVlYiOjoaubm5aG9vV/oUFBSgpqYGZWVl2L59O/bu3YsFCxb0eiOIqO8MBgNuvesuzP7d72CeOxf1iYk4FQjglN+PUyI48/3vI37BAsxZswaTp07t
U/AAAKQPAMjWrVuVvwOBgNhsNlm9erUyr7m5WYxGo2zatElERGprawWAHDp0SOmzc+dO0Wg0cubMmR6t1+PxCADxeDx9KZ+IriAQCIjP5xO32y3nzp2TBrdbOjs7JRAIXHW5UI7Nfr3mc/r0abjdbuTk5CjzzGYzsrKy4HQ6AQBOpxMWiwWZmZlKn5ycHGi1WlRWVl72dX0+H7xeb9BERANHo9HAYDDAarXCZrMh3mqFrgc/txGKfg0ft9sNAJc8c91qtSptbrcb8fHxQe06nQ6xsbFKn28rLS2F2WxWpsTExP4sm4hUEBajXSUlJfB4PMpUX1+vdklE1Ef9Gj42mw0A0NDQEDS/oaFBabPZbGhsbAxq7+rqQlNTk9Ln24xGI0wmU9BEROGtX8MnOTkZNpsN5eXlyjyv14vKykrY7XYAgN1uR3NzM6qqqpQ+e/bsQSAQQFZWVn+WQ0RDWMiPzmltbcXJkyeVv0+fPo2jR48iNjYWSUlJWLJkCf7t3/4N48ePR3JyMp5++mkkJCRg1qxZAIAJEybg7rvvxiOPPIJ169ahs7MTixYtwk9/+lMkJCT024YR0RAX6hDchx9+KPj6Z0CCpsLCQmWI7umnnxar1SpGo1GmT58udXV1Qa9x4cIFmTNnjsTExIjJZJJ58+ZJS0tLj2vgUDvR0BTKscnndhFRv+Fzu4hoyGP4EJEqGD5EpAqGDxGpguFDRKpg+BCRKhg+RKQKhg8RqYLhQ0SqYPgQkSoYPkSkCoYPEamC4UNEqmD4EJEqGD5EpAqGDxGpguFDRKpg+BCRKhg+RKQKhg8RqYLhQ0SqYPgQkSpCCp/S0lJMmTIFI0eORHx8PGbNmoW6urqgPu3t7XA4HIiLi0NMTAzy8/MveXyyy+VCXl4eoqKiEB8fj2XLlqGrq6vvW0NEYSOk8KmoqIDD4cCBAwdQVlaGzs5OzJgxA21tbUqfpUuXYtu2bdiyZQsqKipw9uxZzJ49W2n3+/3Iy8tDR0cH9u/fjw0bNmD9+vVYvnx5/20VEQ19fXk6YWNjowCQiooKERFpbm4WvV4vW7ZsUfocP35cAIjT6RQRkR07dohWqxW32630Wbt2rZhMJvH5fD1aL59YSjQ0hXJs9umaj8fjAQDExsYCAKqqqtDZ2YmcnBylT0pKCpKSkuB0OgEATqcTaWlpsFqtSp/c3Fx4vV7U1NRcdj0+nw9erzdoIqLw1uvwCQQCWLJkCaZNm4aJEycCANxuNwwGAywWS1Bfq9UKt9ut9Plm8HS3d7ddTmlpKcxmszIlJib2tmwiGiJ6HT4OhwPHjh3D5s2b+7OeyyopKYHH41Gm+vr6AV8nEQ0sXW8WWrRoEbZv3469e/fi+uuvV+bbbDZ0dHSgubk56OynoaEBNptN6XPw4MGg1+seDevu821GoxFGo7E3pRLREBXSmY+IYNGiRdi6dSv27NmD5OTkoPaMjAzo9XqUl5cr8+rq6uByuWC32wEAdrsd1dXVaGxsVPqUlZXBZDIhNTW1L9tCRGEkpDMfh8OBN954A++++y5GjhypXKMxm82IjIyE2WzG/PnzUVxcjNjYWJhMJixevBh2ux3Z2dkAgBkzZiA1NRUPP/wwVq1aBbfbjaeeegoOh4NnN0TDSSjDaAAuO7322mtKn4sXL8qjjz4q1113nURFRcl9990n586dC3qdL774QmbOnCmRkZEyatQo+fnPfy6dnZ09roND7URDUyjHpkZERL3o6x2v1wuz2QyPxwOTyaR2OUT0/4VybPLeLiJSBcOHiFTB8CEiVTB8iEgVDB8iUgXDh4hUwfAhIlUwfIhIFQwfIlIFw4eIVMHwISJVMHyISBUMHyJSBcOHiFTB8CEiVTB8iEgVvfoBeaL+5vf78de//hXA1w8MMBgMKldEA43hQ6ry+/2oO1WHTRWb8PH5jyEQJEYnYk7WHGSnZ/N3vb/D+DOqpJquri5s2LEBb37+JiReoNFpoNFoIAGB/4IfU41T
UfJ/S2Aayf/G4YI/o0pDnohg6+6tePP8m0ACoNVrodFoAAAarQa60TocHnEY//7Wv8Pv96tcLQ0Ehg+pwuPxYEvNFmAklND5Nq1BC2ebE9XHqwe5OhoMDB9SxQf7P8AFy4UrBo9iNPDmn95EGF4doGsIKXzWrl2LSZMmwWQywWQywW63Y+fOnUp7e3s7HA4H4uLiEBMTg/z8fOVRyN1cLhfy8vIQFRWF+Ph4LFu2DF1dXf2zNRQ23Bfc0EReI3gAaHVaNLQ0MHy+g0IKn+uvvx4rV65EVVUVDh8+jDvvvBP33nsvampqAABLly7Ftm3bsGXLFlRUVODs2bOYPXu2srzf70deXh46Ojqwf/9+bNiwAevXr8fy5cv7d6toyNPr9EDg2v1EBPoI/cAXRIOuz6NdsbGxWL16Ne6//36MHj0ab7zxBu6//34AwKeffooJEybA6XQiOzsbO3fuxN///d/j7NmzsFqtAIB169bhySefxPnz53v83Q6OdoW/z059hsfeewxy/dXffl3NXVh0wyLMypl17Y9opLpBGe3y+/3YvHkz2traYLfbUVVVhc7OTuTk5Ch9UlJSkJSUBKfTCQBwOp1IS0tTggcAcnNz4fV6lbOny/H5fPB6vUEThbdxyeMwzTwNgc4rn/6ICGxeG+6y38Xg+Q4KOXyqq6sRExMDo9GIhQsXYuvWrUhNTYXb7YbBYIDFYgnqb7Va4Xa7AQButzsoeLrbu9uupLS0FGazWZkSExNDLZuGGK1Wi3/+yT/j+y3fR9fFrqBrOiKCQFcA0Y3R+MXMXyA6OlrFSmmghBw+f/d3f4ejR4+isrISRUVFKCwsRG1t7UDUpigpKYHH41Gm+vr6AV0fDY74UfF49v5nMUs/C8Z6I3x/8cH3Fx/krGBqy1SsvHslbp5wM896vqNCvr3CYDBg3LhxAICMjAwcOnQIv/3tb/Hggw+io6MDzc3NQWc/DQ0NsNlsAACbzYaDBw8GvV73aFh3n8sxGo38mv131Oi40Vj000WYfW42Pvn0EwgEYxPGImV8CiIiItQujwZQn7/nEwgE4PP5kJGRAb1ej/LycqWtrq4OLpcLdrsdAGC321FdXY3GxkalT1lZGUwmE1JTU/taCoUpjUaD/5PwfzDzzpm45857cFPKTQyeYSCkM5+SkhLMnDkTSUlJaGlpwRtvvIGPPvoIu3btgtlsxvz581FcXIzY2FiYTCYsXrwYdrsd2dnZAIAZM2YgNTUVDz/8MFatWgW3242nnnoKDoeDZzZEw0xI4dPY2Ii5c+fi3LlzMJvNmDRpEnbt2oW77roLAPDCCy9Aq9UiPz8fPp8Pubm5ePnll5XlIyIisH37dhQVFcFutyM6OhqFhYVYsWJF/24VEQ15vKudiPoN72onoiGP4UNEqmD4EJEqGD5EpAqGDxGpguFDRKrg0ytoUIgI/H4/vvrqK5w/fx6nT5/GyZMng/qYzWZMmTIFUVFRSEpKgl6v531d32EMHxowIoKuri588cUX2Lt3Lz788EPU19fD4/EAuPxvN4sIDAYDxowZg+zsbEyfPh033XQToqKiGETfMfySIfU7EcHFixfx4Ycf4o033sCJEyfQ2dkJjUbT4wAREeVnNhISEnDPPffg/vvvR3x8PENoCAvl2GT4UL/y+/3Yu3cvXnnlFdTV1YUUOFfSHUSxsbF44IEHMHfuXERGRjKEhiCGDw06EcH58+fxyiuvYPv27cqZTn+vAwBuuukmPPHEE5g4cSIDaIjh7RU0qEQE1dXVcDgc+OMf/4iurq4BCYXus6iamhosXLgQ77zzDnw+X7+vhwYHw4f6RETwwQcfYOHChTh16hS02oF/S2k0Gly8eBErVqzAr371K7S3tw/4Oqn/cbSLeq07eFasWIH29vZB/QjUfRb09ttvAwCKi4sxYsSIQVs/9R3PfKjXdu/ejRUrVuDixYuqXnt5++238etf/xqdnZ2q1UChY/hQr7hcLvzmN79RPXi6
/fGPf8Tu3bv5ZNMwwvChkLW3t+O5557DuXPnhkTwAF//lnhpaSlOnz6tdinUQwwfCkkgEMDbb7+NgwcPDpngAb6+BtTa2orVq1fzAnSYYPhQSFwuF9atWzekguebDhw4gPfee0/tMqgHGD7UY4FAAJs2bUJra+uQDJ/uEbCNGzfykdphgOFDPeZyubBjx44hGTzdNBoNXC4Xdu3apXYpdA0MH+qxvXv3wuv1DunwAb4OoJ07d3LofYjrU/isXLkSGo0GS5YsUea1t7fD4XAgLi4OMTExyM/PVx6J3M3lciEvLw9RUVGIj4/HsmXL0NXV1ZdSaIB1dHRgz549YfMk0draWpw5c0btMugqeh0+hw4dwiuvvIJJkyYFzV+6dCm2bduGLVu2oKKiAmfPnsXs2bOVdr/fj7y8PHR0dGD//v3YsGED1q9fj+XLl/d+K2jAffnllzh+/PiQP+sB/nb7xb59+9Quha6iV+HT2tqKgoICvPrqq7juuuuU+R6PB//xH/+BX//617jzzjuRkZGB1157Dfv378eBAwcAAB988AFqa2vx+uuv4+abb8bMmTPx7LPPYs2aNejo6OifraJ+V1tbi4sXL6pdRo9FRETg0KFDCAQCapdCV9Cr8HE4HMjLy0NOTk7Q/KqqKnR2dgbNT0lJQVJSEpxOJwDA6XQiLS0NVqtV6ZObmwuv14uamprLrs/n88Hr9QZNNLj27dsHnS68bgWsrq5GW1ub2mXQFYT8btq8eTOOHDmCQ4cOXdLmdrthMBhgsViC5lutVrjdbqXPN4Onu7277XJKS0vxy1/+MtRSqZ8EAoGwO4g1Gg06Ojrg8/kwcuRItcuhywjpzKe+vh6PP/44Nm7cOKh3EJeUlMDj8ShTfX39oK2bvv6YfezYMbXLCFlLSwtqa2vVLoOuIKTwqaqqQmNjI2655RbodDrodDpUVFTgxRdfhE6ng9VqRUdHB5qbm4OWa2hogM1mAwDYbLZLRr+6/+7u821GoxEmkyloosHT/UPw4XCx+Zu666ahKaTwmT59Oqqrq3H06FFlyszMREFBgfJvvV6P8vJyZZm6ujq4XC7Y7XYAgN1uR3V1NRobG5U+ZWVlMJlMSE1N7afNIqKhLqRrPiNHjsTEiROD5kVHRyMuLk6ZP3/+fBQXFyM2NhYmkwmLFy+G3W5HdnY2AGDGjBlITU3Fww8/jFWrVsHtduOpp56Cw+GA0Wjsp80ioqGu34cvXnjhBWi1WuTn58Pn8yE3Nxcvv/yy0h4REYHt27ejqKgIdrsd0dHRKCwsxIoVK/q7FOpnIhJ2H71o6OLTK+iaOjs7UVRUhCNHjoRV+BgMBrz11ltITExUu5Rhg0+voH6l1+vDLuRFJCzrHk4YPtQjU6dOhd/vV7uMkKSkpCA6OlrtMugKGD7UIxMmTIBer1e7jB4LBAJIS0sLu29lDycMH+qRG2+8EWPHjg2bH2jX6XS4/fbb1S6DroLhQz0SGRmJH/7wh2Fxo6aIIDk5GePGjVO7FLoKhg/12B133AG9Xj/kz34CgQB+9KMfITIyUu1S6CoYPtRjqampmDZtmtplXJWIIDY2Fvn5+WqXQtfA8KEe0+l0KCwsRERExJA9+wkEApg1a9YV7xOkoYPhQyFJS0tDfn7+kAwfEcG4cePws5/9LKy+DDlcMXwoJDqdDv/0T/+E5OTkIRVAIgKdToelS5di1KhRapdDPcDwoZCNGjUKy5YtQ0xMzJAKoLlz5yIrK0vtMqiHGD7UK9nZ2XjooYeg1ar/FhIR3HLLLZg3bx6/VBhG1H/nUFjSaDSYN28eCgsLVb2+IiJIT0/HypUreStFmGH4UK/p9XosXLgQhYWFADDoH8ECgYASPLzOE354jkp90h1Ao0ePxrp16wbliaYiAq1Wi/vuuw+PPvoogydMMXyoz/R6PR588EHceOONeOGFF1BTUzNgASQisFgsWLhwIe67776wutmVgvFjF/ULjUaD9PR0rFmzBv/4
j/+IqKiofr0PrPtXFG+99Vb84Q9/wAMPPMDgCXP8JUPqdyKCP//5z9i4cSPef/99tLa2QqPRhHw2JCIIBALQ6/XIyMhAYWEhMjMzodPp+CXCISqUY5PhQwNGROB2u3HgwAHs3r0bn3zyCdra2hAIBKDRaC4Zpu8OG41GA51Oh3HjxuGOO+7A7bffjuTkZIZOGGD40JAiIhARnD9/HidPnsSpU6dw+vRpnDx5Mqif2WzGlClTYLFYkJ6eDpvNBr1ez8AJI6Ecm7zgTAOu+yOX1WqF1WrFtGnTrjosz7AZHhg+pAoGDIU02vWv//qvyv/FuqeUlBSlvb29HQ6HA3FxcYiJiUF+fv4lj0Z2uVzIy8tDVFQU4uPjsWzZMj7SlmgYCvnM56abbsLu3bv/9gLfuJdm6dKl+O///m9s2bIFZrMZixYtwuzZs/E///M/AAC/34+8vDzYbDbs378f586dw9y5c6HX6/H888/3w+YQUdiQEDzzzDMyefLky7Y1NzeLXq+XLVu2KPOOHz8uAMTpdIqIyI4dO0Sr1Yrb7Vb6rF27Vkwmk/h8vh7X4fF4BIB4PJ5QyieiARbKsRnylwxPnDiBhIQEfO9730NBQQFcLhcAoKqqCp2dncjJyVH6pqSkICkpCU6nEwDgdDqRlpYGq9Wq9MnNzYXX60VNTc0V1+nz+eD1eoMmIgpvIYVPVlYW1q9fj/fffx9r167F6dOn8YMf/AAtLS1wu90wGAywWCxBy1itVrjdbgCA2+0OCp7u9u62KyktLYXZbFYmPv6WKPyFdM1n5syZyr8nTZqErKwsjB07Fm+99daAPimgpKQExcXFyt9er5cBRBTm+nRvl8ViwY033oiTJ0/CZrOho6MDzc3NQX0aGhqUH/O22WyXjH51/321H/w2Go0wmUxBExGFtz6FT2trK06dOoUxY8YgIyMDer0e5eXlSntdXR1cLhfsdjsAwG63o7q6Go2NjUqfsrIymEwmpKam9qUUIgozIX3s+sUvfoEf//jHGDt2LM6ePYtnnnkGERERmDNnDsxmM+bPn4/i4mLExsbCZDJh8eLFsNvtyM7OBgDMmDEDqampePjhh7Fq1Sq43W489dRTcDgcMBqNA7KBRDQ0hRQ+X375JebMmYMLFy5g9OjRuO2223DgwAGMHj0aAPDCCy9Aq9UiPz8fPp8Pubm5ePnll5XlIyIisH37dhQVFcFutyM6OhqFhYVYsWJF/24VEQ15vLGUiPpNKMcmf0yMiFTB8CEiVTB8iEgVDB8iUgXDh4hUwfAhIlUwfIhIFQwfIlIFw4eIVMHwISJVMHyISBUMHyJSBcOHiFTB8CEiVTB8iEgVDB8iUgXDh4hUwfAhIlUwfIhIFQwfIlIFw4eIVMHwISJVhBw+Z86cwUMPPYS4uDhERkYiLS0Nhw8fVtpFBMuXL8eYMWMQGRmJnJwcnDhxIug1mpqaUFBQAJPJBIvFgvnz56O1tbXvW0NEYSOk8PnLX/6CadOmQa/XY+fOnaitrcWvfvUrXHfddUqfVatW4cUXX8S6detQWVmJ6Oho5Obmor29XelTUFCAmpoalJWVYfv27di7dy8WLFjQf1tFREOfhODJJ5+U22677YrtgUBAbDabrF69WpnX3NwsRqNRNm3aJCIitbW1AkAOHTqk9Nm5c6doNBo5c+ZMj+rweDwCQDweTyjlE9EAC+XYDOnM57333kNmZiYeeOABxMfHIz09Ha+++qrSfvr0abjdbuTk5CjzzGYzsrKy4HQ6AQBOpxMWiwWZmZlKn5ycHGi1WlRWVl52vT6fD16vN2giovAWUvh8/vnnWLt2LcaPH49du3ahqKgIjz32GDZs2AAAcLvdAACr1Rq0nNVqVdrcbjfi4+OD2nU6HWJjY5U+31ZaWgqz2axMiYmJoZRNRENQSOETCARwyy234Pnnn0d6ejoWLFiARx55BOvWrRuo+gAAJSUl8Hg8ylRfXz+g6yOigRdS+IwZMwapqalB
8yZMmACXywUAsNlsAICGhoagPg0NDUqbzWZDY2NjUHtXVxeampqUPt9mNBphMpmCJiIKbyGFz7Rp01BXVxc077PPPsPYsWMBAMnJybDZbCgvL1favV4vKisrYbfbAQB2ux3Nzc2oqqpS+uzZsweBQABZWVm93hAiCjOhXMk+ePCg6HQ6ee655+TEiROyceNGiYqKktdff13ps3LlSrFYLPLuu+/KJ598Ivfee68kJyfLxYsXlT533323pKenS2Vlpezbt0/Gjx8vc+bMGZAr6kQ0eEI5NkMKHxGRbdu2ycSJE8VoNEpKSor8/ve/D2oPBALy9NNPi9VqFaPRKNOnT5e6urqgPhcuXJA5c+ZITEyMmEwmmTdvnrS0tPS4BoYP0dAUyrGpERFR99wrdF6vF2azGR6Ph9d/iIaQUI5N3ttFRKpg+BCRKhg+RKQKhg8RqYLhQ0SqYPgQkSoYPkSkCoYPEamC4UNEqmD4EJEqGD5EpAqGDxGpguFDRKpg+BCRKhg+RKQKhg8RqYLhQ0SqYPgQkSoYPkSkCoYPEamC4UNEqmD4EJEqGD5EpAqGDxGpguFDRKrQqV1Ab3Q/ZNXr9apcCRF9U/cx2ZMHIYdl+Fy4cAEAkJiYqHIlRHQ5LS0tMJvNV+0TluETGxsLAHC5XNfcwO8yr9eLxMRE1NfXD+tn1nM/fG0o7AcRQUtLCxISEq7ZNyzDR6v9+lKV2Wwe1m+2biaTifsB3A/d1N4PPT0h4AVnIlIFw4eIVBGW4WM0GvHMM8/AaDSqXYqquB++xv3wtXDbDxrpyZgYEVE/C8szHyIKfwwfIlIFw4eIVMHwISJVhGX4rFmzBjfccANGjBiBrKwsHDx4UO2S+k1paSmmTJmCkSNHIj4+HrNmzUJdXV1Qn/b2djgcDsTFxSEmJgb5+floaGgI6uNyuZCXl4eoqCjEx8dj2bJl6OrqGsxN6VcrV66ERqPBkiVLlHnDZT+cOXMGDz30EOLi4hAZGYm0tDQcPnxYaRcRLF++HGPGjEFkZCRycnJw4sSJoNdoampCQUEBTCYTLBYL5s+fj9bW1sHelGASZjZv3iwGg0H+8Ic/SE1NjTzyyCNisVikoaFB7dL6RW5urrz22mty7NgxOXr0qNxzzz2SlJQkra2tSp+FCxdKYmKilJeXy+HDhyU7O1tuvfVWpb2rq0smTpwoOTk58vHHH8uOHTtk1KhRUlJSosYm9dnBgwflhhtukEmTJsnjjz+uzB8O+6GpqUnGjh0r//AP/yCVlZXy+eefy65du+TkyZNKn5UrV4rZbJZ33nlH/vd//1d+8pOfSHJysly8eFHpc/fdd8vkyZPlwIED8qc//UnGjRsnc+bMUWOTFGEXPlOnThWHw6H87ff7JSEhQUpLS1WsauA0NjYKAKmoqBARkebmZtHr9bJlyxalz/HjxwWAOJ1OERHZsWOHaLVacbvdSp+1a9eKyWQSn883uBvQRy0tLTJ+/HgpKyuTH/3oR0r4DJf98OSTT8ptt912xfZAICA2m01Wr16tzGtubhaj0SibNm0SEZHa2loBIIcOHVL67Ny5UzQajZw5c2bgir+GsPrY1dHRgaqqKuTk5CjztFotcnJy4HQ6Vaxs4Hg8HgB/u5m2qqoKnZ2dQfsgJSUFSUlJyj5wOp1IS0uD1WpV+uTm5sLr9aKmpmYQq+87h8OBvLy8oO0Fhs9+eO+995CZmYkHHngA8fHxSE9Px6uvvqq0nz59Gm63O2g/mM1mZGVlBe0Hi8WCzMxMpU9OTg60Wi0qKysHb2O+JazC56uvvoLf7w96MwGA1WqF2+1WqaqBEwgEsGTJEkybNg0TJ04EALjdbhgMBlgslqC+39wHbrf7svuouy1cbN68GUeOHEFpaeklbcNlP3z++edYu3Ytxo8fj127dqGoqAiPPfYYNmzYAOBv23G1Y8LtdiM+Pj6oXafTITY2VtX9EJZ3tQ8XDocDx44dw759+9QuZdDV19fj8ccf
R1lZGUaMGKF2OaoJBALIzMzE888/DwBIT0/HsWPHsG7dOhQWFqpcXd+E1ZnPqFGjEBERccmIRkNDA2w2m0pVDYxFixZh+/bt+PDDD3H99dcr8202Gzo6OtDc3BzU/5v7wGazXXYfdbeFg6qqKjQ2NuKWW26BTqeDTqdDRUUFXnzxReh0Olit1mGxH8aMGYPU1NSgeRMmTIDL5QLwt+242jFhs9nQ2NgY1N7V1YWmpiZV90NYhY/BYEBGRgbKy8uVeYFAAOXl5bDb7SpW1n9EBIsWLcLWrVuxZ88eJCcnB7VnZGRAr9cH7YO6ujq4XC5lH9jtdlRXVwe94crKymAymS55Iw9V06dPR3V1NY4ePapMmZmZKCgoUP49HPbDtGnTLvmqxWeffYaxY8cCAJKTk2Gz2YL2g9frRWVlZdB+aG5uRlVVldJnz549CAQCyMrKGoStuALVLnX30ubNm8VoNMr69eultrZWFixYIBaLJWhEI5wVFRWJ2WyWjz76SM6dO6dMf/3rX5U+CxculKSkJNmzZ48cPnxY7Ha72O12pb17iHnGjBly9OhRef/992X06NFhNcR8Od8c7RIZHvvh4MGDotPp5LnnnpMTJ07Ixo0bJSoqSl5//XWlz8qVK8Visci7774rn3zyidx7772XHWpPT0+XyspK2bdvn4wfP55D7b3x0ksvSVJSkhgMBpk6daocOHBA7ZL6DYDLTq+99prS5+LFi/Loo4/KddddJ1FRUXLffffJuXPngl7niy++kJkzZ0pkZKSMGjVKfv7zn0tnZ+cgb03/+nb4DJf9sG3bNpk4caIYjUZJSUmR3//+90HtgUBAnn76abFarWI0GmX69OlSV1cX1OfChQsyZ84ciYmJEZPJJPPmzZOWlpbB3IxL8Cc1iEgVYXXNh4i+Oxg+RKQKhg8RqYLhQ0SqYPgQkSoYPkSkCoYPEamC4UNEqmD4EJEqGD5EpAqGDxGpguFDRKr4f7xQ9sHbGC34AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[30.0, -30.0]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Play one rendered episode; the output is each agent's total reward.\n",
    "play(show=True)[-1]"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第9章-策略梯度算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
